
Enable training strategy for Indexer #3415

Merged
copybara-service[bot] merged 1 commit into main from indexer_train_strategy on Mar 20, 2026

Conversation

@RissyRan (Collaborator) commented Mar 14, 2026

Description

Enable the selective parameter training strategy for the DeepSeek V3.2 Indexer - paper

  • Dense warm up stage:
    • Add trainable_parameters_mask flag, allowing specific parameters to be targeted for training while freezing the rest of the model.
    • Add TrainableParametersMaskTest unit tests for validation.
  • Sparse training stage:
    • Update indexer_sparse_training flag to indicate Dense Warm-up stage or Sparse Training stage for DS v3.2.
    • Add test_indexer_gradients unit test to verify proper gradient isolation.
  • Renaming flags to avoid confusion:
    • use_sparse_indexer → use_indexer; index_head_dim → indexer_head_dim; index_n_heads → indexer_n_heads; index_topk → indexer_topk
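The idea behind the trainable_parameters_mask flag can be sketched as follows. This is an illustrative sketch only: the flag's real semantics, the parameter names, and the helper `build_mask` are assumptions, not MaxText's actual API. It shows the core mechanism, matching parameter paths against patterns to mark a subset trainable and freeze the rest:

```python
# Hypothetical sketch of a selective-training mask: True = trainable,
# False = frozen. Real MaxText parameter names/paths will differ.
import re

def build_mask(params, patterns):
  """Return a nested dict of booleans mirroring the params tree."""
  def walk(subtree, path):
    if isinstance(subtree, dict):
      return {k: walk(v, f"{path}/{k}" if path else k) for k, v in subtree.items()}
    # Leaf: trainable iff its path matches any pattern.
    return any(re.search(p, path) for p in patterns)
  return walk(params, "")

params = {"decoder": {"attn": {"w": 0}, "indexer": {"wq": 0, "wk": 0}}}
mask = build_mask(params, [r"indexer"])
print(mask)
# {'decoder': {'attn': {'w': False}, 'indexer': {'wq': True, 'wk': True}}}
```

In a JAX setting the same boolean tree can be fed to an optimizer mask so gradients for frozen leaves are dropped before the update step.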

Tests

  • Expect all added unit tests to be green
  • End-to-end functional - logs
  • Sanity check deepseek32_vs_reference_test - link, same as b/491486716
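A gradient-isolation check like the added test_indexer_gradients can be sketched as below. All names here are illustrative assumptions, not the PR's actual test code; the sketch only demonstrates the property being verified, that an indexer-style loss with a detached target produces zero gradient for the rest of the model:

```python
# Hypothetical gradient-isolation check (names are illustrative).
import jax
import jax.numpy as jnp

params = {"indexer": jnp.array(1.5), "backbone": jnp.array(1.0)}

def indexer_loss(p, x):
  # The target comes from the backbone but is detached, so only the
  # indexer branch should receive gradient from this loss.
  target = jax.lax.stop_gradient(p["backbone"] * x)
  pred = p["indexer"] * x
  return (pred - target) ** 2

grads = jax.grad(indexer_loss)(params, jnp.array(2.0))
print(grads)  # backbone gradient is exactly zero; indexer gradient is not
```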

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov Bot commented Mar 14, 2026

Codecov Report

❌ Patch coverage is 75.40984% with 15 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/pre_train/train.py | 25.00% | 11 Missing and 1 partial ⚠️ |
| src/maxtext/optimizers/optimizers.py | 89.47% | 2 Missing ⚠️ |
| src/maxtext/layers/attention_op.py | 0.00% | 0 Missing and 1 partial ⚠️ |


@github-actions Bot commented:

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@RissyRan changed the title from "Enable selective parameter training strategy" to "Enable training strategy for Indexer" on Mar 14, 2026
@RissyRan force-pushed the indexer_train_strategy branch from 85e4e0d to b0f353b on March 14, 2026 02:20
@github-actions Bot left a comment

📋 Review Summary

This PR enables the selective parameter training strategy (dense warm-up and sparse training stages) for the DeepSeek V3.2 Indexer. It refactors parameter freezing flags and adds tests to verify proper isolation of indexer gradients from the rest of the model.

🔍 General Feedback

  • Memory Optimization in Selective Training: The current implementation of optimizer masking computes and stores Adam state parameters for the entire model before zeroing out the updates. I've suggested an explicit mapping with optax.multi_transform to avoid allocating massive memory blocks for frozen parameter states, which is critical for 671B model scaling.
  • Gradient Isolation in KL Divergence: I left an inline comment pointing out a gradient leak when calculating the KL divergence in calculate_indexer_loss. Ensure jax.lax.stop_gradient is applied to the target attention_probs distribution, so that the main model's queries and keys do not get updated by the indexer's loss.

Comment thread src/maxtext/optimizers/optimizers.py
Comment thread src/maxtext/layers/attention_mla.py Outdated
@RissyRan force-pushed the indexer_train_strategy branch 4 times, most recently from 4ec8a1e to 906f12f on March 14, 2026 03:18
@shuningjin (Collaborator) left a comment

Thanks! Sparse training looks good to me. I left a suggestion on dense warmup, along with minor comments. Will take another look at trainable_parameters_mask soon.

Comment thread src/maxtext/layers/attention_mla.py Outdated
Comment thread tests/unit/deepseek32_vs_reference_test.py Outdated
Comment thread tests/unit/flop_calculation_test.py
Comment thread tests/unit/train_compile_test.py
Comment thread src/maxtext/configs/base.yml
Comment thread src/maxtext/layers/attention_mla.py Outdated
Comment thread src/maxtext/layers/attention_mla.py Outdated
@RissyRan force-pushed the indexer_train_strategy branch 8 times, most recently from be3bad4 to d3defd7 on March 20, 2026 05:07
@shuningjin (Collaborator) left a comment

Looks great! Thanks for the change.

Comment thread tests/unit/deepseek32_vs_reference_test.py Outdated
@Rohan-Bierneni (Collaborator) left a comment

Thank you for the changes. I have left a small nit.

Comment thread src/maxtext/optimizers/optimizers.py Outdated
@RissyRan force-pushed the indexer_train_strategy branch 2 times, most recently from f27e777 to fef4e96 on March 20, 2026 20:34
@RissyRan force-pushed the indexer_train_strategy branch from fef4e96 to 0b55a28 on March 20, 2026 20:36
@copybara-service Bot merged commit cc0d3ae into main on Mar 20, 2026
31 checks passed
@copybara-service Bot deleted the indexer_train_strategy branch on March 20, 2026 22:32

3 participants